import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
import matplotlib.pylab as plt
from tabulate import tabulate
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

#将图像转换为张量并归一化
transform = transforms.Compose([
    transforms.ToTensor(),   #将图像转变为张量
    transforms.Normalize((0.1307,),(0.3081,)) #归一化
])
#加载训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform) #训练集
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform) #测试集

#划分训练集和测试集，总长度的80%为训练集，20%为测试集
train_size = int(0.8*len(train_dataset))   #训练集，总长度的80%
val_size = len(train_dataset) - train_size   #测试集，总长度-训练集=测试集
train_dataset,val_dataset = torch.utils.data.random_split(train_dataset,[train_size,val_size])

#创建数据加载器
train_loader = DataLoader(train_dataset,batch_size=32,shuffle=True)  #训练集，每个批次大小32个样品，每次迭代周期打乱数据
val_loader = DataLoader(val_dataset,batch_size=32,shuffle=False)   #测试集，每个批次大小32个样品，不打乱数据
test_loader = DataLoader(test_dataset,batch_size=32,shuffle=False)  #验证集，每次批次大小32个样品，不打乱数据

#定义多层感知机
class MLP(nn.Module):
    def __init__(self):
        super(MLP,self).__init__()
        self.flatten = nn.Flatten()  #将二维图像转变为一维
        self.fc1 = nn.Linear(28*28,20)  #输出张量尺寸28*28，有20个神经元（输入维度28*28，输出维度20）
        self.relu = nn.ReLU()   #单独层声明Relu激活函数
        self.fc2 = nn.Linear(20,10)  #输入维度20，输出维度10
        
    def forward(self,x):
        x = self.flatten(x)  #展平为一维
        x = self.fc1(x)   #通过第一个全连接层
        x = self.relu(x)  #应用激活函数
        x = self.fc2(x) #通过第二个激活函数
        return x   #返回x

#创建多层感知机模型,并移动到指定设备上
model = MLP().to(device)

if model.training == True:
    print("\n\n模型当前(默认)处于训练模式")

print(f"开始在{device}上训练: ")
#定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  #创建了一个交叉熵损失函数对象,用于衡量模型的预测结果与真实标签之间的差异
optimizer = optim.Adam(model.parameters(), lr=0.001)  #创建了一个Adam优化器对象,用于更新模型的参数

num_epochs = 20  #训练轮数为20
for epoch in range(num_epochs):   #一批一批的遍历训练数据集
    model.train()  #训练模式
    train_loss = 0.0
    train_acc = 0.0
    for images,labels in train_loader: #一批一批的遍历训练数据集
        images = images.to(device)  # 将图像张量移动到指定设备上
        labels = labels.to(device)  # 将标签张量移动到指定设备上
        outputs = model(images)   #载入训练图像
        loss = criterion(outputs, labels)  #计算损失，计算输出和真实标签之间的损失
        optimizer.zero_grad()   #梯度清零
        loss.backward()    #反向传播计算损失函数
        optimizer.step()  #优化器更新参数
        train_loss += loss.item()*images.size(0)   #计算样品总损失
        _,predicted = torch.max(outputs.data,1)
        train_acc += (predicted == labels).sum().item()  #预测正确样本总数，将Tensor转化Python数值，
    train_loss /= len(train_dataset)  #计算平均损失
    train_acc /= len(train_dataset)   # 计算平均训练准确率,除以训练集总样本数
    # 验证阶段
    model.eval()  # 将模型设置为评估模式
    val_loss = 0.0  # 初始化验证损失
    val_acc = 0.0  # 初始化验证准确率

    with torch.no_grad():  # 关闭梯度计算
        for images, labels in val_loader:  #一批一批的遍历验证数据集
            images = images.to(device)  # 将图像张量移动到指定设备上
            labels = labels.to(device)    # 将标签张量移动到指定设备上
            outputs = model(images)  #载入预测图像
            loss = criterion(outputs, labels)  #计算损失，计算输出和真实标签之间的损失

            val_loss += loss.item() * images.size(0)  #计算样品总损失
            _, predicted = torch.max(outputs.data, 1)  # 获取预测概率最大的类别索引
            val_acc += (predicted == labels).sum().item()  #预测正确样本总数，将Tensor转化Python数值，

    val_loss /= len(val_dataset)  # 计算平均验证损失,除以验证集总样本数
    val_acc /= len(val_dataset)  # 计算验证准确率、 
    #打印训练损失、训练准确率、验证损失和验证准确率
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
pytorch_path = 'pytorch_path.pth'
torch.save(model.state_dict(), pytorch_path)
print(f"模型参数已保存到: {pytorch_path}")
# 获取并打印文件大小
pytorch_size = os.path.getsize(pytorch_path)
print(f"模型参数文件大小: {pytorch_size} 字节")